{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DweYe9FcbMK_"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "AVV2e0XKbJeX"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sUtoed20cRJJ"
      },
      "source": [
        "# tf.data を使って CSV をロードする"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1ap_W4aQcgNT"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/load_data/csv\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/tutorials/load_data/csv.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/tutorials/load_data/csv.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/tutorials/load_data/csv.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RSywPQ2n736s"
      },
      "source": [
        "Note: これらのドキュメントは私たちTensorFlowコミュニティが翻訳したものです。コミュニティによる 翻訳は**ベストエフォート**であるため、この翻訳が正確であることや[英語の公式ドキュメント](https://www.tensorflow.org/?hl=en)の 最新の状態を反映したものであることを保証することはできません。 この翻訳の品質を向上させるためのご意見をお持ちの方は、GitHubリポジトリ[tensorflow/docs](https://github.com/tensorflow/docs)にプルリクエストをお送りください。 コミュニティによる翻訳やレビューに参加していただける方は、 [docs-ja@tensorflow.org メーリングリスト](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-ja)にご連絡ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C-3Xbt0FfGfs"
      },
      "source": [
        "このチュートリアルでは、CSV データを `tf.data.Dataset` にロードする手法の例を示します。\n",
        "\n",
        "このチュートリアルで使われているデータはタイタニック号の乗客リストから取られたものです。乗客が生き残る可能性を、年齢、性別、チケットの等級、そして乗客が一人で旅行しているか否かといった特性から予測することを試みます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fgZ9gjmPfSnK"
      },
      "source": [
        "## 設定"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "baYFZMW_bJHh"
      },
      "outputs": [],
      "source": [
        "import functools\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ncf5t6tgL5ZI"
      },
      "outputs": [],
      "source": [
        "TRAIN_DATA_URL = \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\"\n",
        "TEST_DATA_URL = \"https://storage.googleapis.com/tf-datasets/titanic/eval.csv\"\n",
        "\n",
        "train_file_path = tf.keras.utils.get_file(\"train.csv\", TRAIN_DATA_URL)\n",
        "test_file_path = tf.keras.utils.get_file(\"eval.csv\", TEST_DATA_URL)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4ONE94qulk6S"
      },
      "outputs": [],
      "source": [
        "# numpy の値を読みやすくする\n",
        "np.set_printoptions(precision=3, suppress=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wuqj601Qw0Ml"
      },
      "source": [
        "## データのロード\n",
        "\n",
        "それではいつものように、扱っている CSV ファイルの先頭を見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "54Dv7mCrf9Yw"
      },
      "outputs": [],
      "source": [
        "!head {train_file_path}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YOYKQKmMj3D6"
      },
      "source": [
        "[pandasでデータを読み込んでから](pandas_dataframe.ipynb) Numpy の array を TensorFlow に渡すというのは可能です。もし、大規模なデータに対応するためにスケールアップしたり、データの読み込み処理を [TensorFlow や tf.data](../../guide/data.ipynb) に統合されたものにしたい場合、`tf.data.experimental.make_csv_dataset` 関数が利用できます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "67mfwr4v-mN_"
      },
      "source": [
        "The only column you need to identify explicitly is the one with the value that the model is intended to predict. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iXROZm5f3V4E"
      },
      "outputs": [],
      "source": [
        "LABELS = [0, 1]\n",
        "LABEL_COLUMN = 'survived'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t4N-plO4tDXd"
      },
      "source": [
        "コンストラクタの引数の値が揃ったので、ファイルから CSV データを読み込みデータセットを作ることにしましょう。\n",
        "\n",
        "（完全なドキュメントは、`tf.data.experimental.make_csv_dataset` を参照してください）"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Co7UJ7gpNADC"
      },
      "outputs": [],
      "source": [
        "def get_dataset(file_path, **kwargs):\n",
        "  dataset = tf.data.experimental.make_csv_dataset(\n",
        "      file_path,\n",
        "      batch_size=5, # Artificially small to make examples easier to show.\n",
        "      label_name=LABEL_COLUMN,\n",
        "      na_value=\"?\",\n",
        "      num_epochs=1,\n",
        "      ignore_errors=True, \n",
        "      **kwargs)\n",
        "  return dataset\n",
        "\n",
        "raw_train_data = get_dataset(train_file_path)\n",
        "raw_test_data = get_dataset(test_file_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IHYHy_dN8lmd"
      },
      "outputs": [],
      "source": [
        "def show_batch(dataset):\n",
        "  for batch, label in dataset.take(1):\n",
        "    for key, value in batch.items():\n",
        "      print(\"{:20s}: {}\".format(key,value.numpy()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vHUQFKoQI6G7"
      },
      "source": [
        "データセットを構成する要素は、(複数のサンプル, 複数のラベル)の形のタプルとして表されるバッチです。サンプル中のデータは（行ベースのテンソルではなく）列ベースのテンソルとして構成され、それぞれはバッチサイズ（このケースでは5個）の要素が含まれます。\n",
        "\n",
        "実際に見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qWtFYtwXIeuj"
      },
      "outputs": [],
      "source": [
        "show_batch(raw_train_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aVubDtIX8lml"
      },
      "source": [
        "上で見たように、この CSV の列には名前がついています。Dataset のコンストラクターはこれらの列名を自動的に抽出します。一行目に列名が記されていない CSV を扱う場合には、列名のリストを `make_csv_dataset` 関数の `column_names` 引数に渡してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Uu1YvDy18lmm"
      },
      "outputs": [],
      "source": [
        "CSV_COLUMNS = ['survived', 'sex', 'age', 'n_siblings_spouses', 'parch', 'fare', 'class', 'deck', 'embark_town', 'alone']\n",
        "\n",
        "temp_dataset = get_dataset(train_file_path, column_names=CSV_COLUMNS)\n",
        "\n",
        "show_batch(temp_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uMghf1Gd8lmp"
      },
      "source": [
        "このサンプルでは利用できるすべての列を使います。もし、データセットから特定の列を除外したい場合には、利用したい列名のリストを作成し、コンストラクタのオプショナルな引数である `select_columns` にそのリストを渡してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3hGJIu788lmp"
      },
      "outputs": [],
      "source": [
        "SELECT_COLUMNS = ['survived', 'age', 'n_siblings_spouses', 'class', 'deck', 'alone']\n",
        "\n",
        "temp_dataset = get_dataset(train_file_path, select_columns=SELECT_COLUMNS)\n",
        "\n",
        "show_batch(temp_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9cryz31lxs3e"
      },
      "source": [
        "## データの前処理\n",
        "\n",
        "CSV ファイルに含まれるデータには様々な型のものがありえます。データをモデルに入力する前に、典型的なケースでは、様々な型を含んだデータを固定長のベクトルに変換する必要があります。\n",
        "\n",
        "TensorFlow には典型的なデータ変換を行う `tf.feature_column` が組み込まれています。詳細は[このチュートリアル](../keras/feature_columns)を参照してください。\n",
        "\n",
        "また、任意のツール (たとえば、[nltk](https://www.nltk.org/) や [sklearn](https://scikit-learn.org/stable/) ) を使って前処理を行い、前処理した結果を TensorFlow にわたすこともできます。\n",
        "\n",
        "モデルの内部で前処理を行う最大の利点は、モデルをエクスポートしたときにその中に前処理が含まれていることです。この方法を取れば、生データをそのままモデルに入力することができます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9AsbaFmCeJtF"
      },
      "source": [
        "### 連続データ"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o2maE8d2ijsq"
      },
      "source": [
        "すでにデータが適切に数値型になっている場合、モデルに渡す前にそのデータをベクトル形式にできます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IwGOy61lkQw-"
      },
      "outputs": [],
      "source": [
        "SELECT_COLUMNS = ['survived', 'age', 'n_siblings_spouses', 'parch', 'fare']\n",
        "DEFAULTS = [0, 0.0, 0.0, 0.0, 0.0]\n",
        "temp_dataset = get_dataset(train_file_path, \n",
        "                           select_columns=SELECT_COLUMNS,\n",
        "                           column_defaults = DEFAULTS)\n",
        "\n",
        "show_batch(temp_dataset)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v9GRXQKP8lmx"
      },
      "outputs": [],
      "source": [
        "example_batch, labels_batch = next(iter(temp_dataset)) "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Yh8R7BujTAu"
      },
      "source": [
        "次のサンプルは、数値型の列をベクトルに変換するシンプルな関数の例です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iNE_mTJqegGQ"
      },
      "outputs": [],
      "source": [
        "def pack(features, label):\n",
        "  return tf.stack(list(features.values()), axis=-1), label"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "raZtRlmaj-A5"
      },
      "source": [
        "この関数をデータセットのそれぞれの要素に適用します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G-t_RSBrM2Vm"
      },
      "outputs": [],
      "source": [
        "packed_dataset = temp_dataset.map(pack)\n",
        "\n",
        "for features, labels in packed_dataset.take(1):\n",
        "  print(features.numpy())\n",
        "  print()\n",
        "  print(labels.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M9lMLaEsjq3K"
      },
      "source": [
        "様々な型を含んだデータセットを利用する場合、シンプルな数値型のフィールドだけを分離したくなるかもしれません。`tf.feature_column` はそのような処理を実現できますが、いくらかのオーバーヘッドを発生させるため、それが本当に必要になる場合以外ではできるだけ避けるべきでしょう。様々な型を含んだデータセットに話を戻しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eb7XlwSv8lnA"
      },
      "outputs": [],
      "source": [
        "show_batch(raw_train_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-sOMGWvE8lnD"
      },
      "outputs": [],
      "source": [
        "example_batch, labels_batch = next(iter(temp_dataset))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z83C3YRy8lnG"
      },
      "source": [
        "数値型の特徴量を選んで、それらをベクトル化して単一の列に変換する、より一般的な前処理器を作成しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FP3iiwvB8lnG"
      },
      "outputs": [],
      "source": [
        "class PackNumericFeatures(object):\n",
        "  def __init__(self, names):\n",
        "    self.names = names\n",
        "\n",
        "  def __call__(self, features, labels):\n",
        "    numeric_features = [features.pop(name) for name in self.names]\n",
        "    numeric_features = [tf.cast(feat, tf.float32) for feat in numeric_features]\n",
        "    numeric_features = tf.stack(numeric_features, axis=-1)\n",
        "    features['numeric'] = numeric_features\n",
        "\n",
        "    return features, labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YH4pGR8B8lnJ"
      },
      "outputs": [],
      "source": [
        "NUMERIC_FEATURES = ['age','n_siblings_spouses','parch', 'fare']\n",
        "\n",
        "packed_train_data = raw_train_data.map(\n",
        "    PackNumericFeatures(NUMERIC_FEATURES))\n",
        "\n",
        "packed_test_data = raw_test_data.map(\n",
        "    PackNumericFeatures(NUMERIC_FEATURES))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "517mwDbe8lnN"
      },
      "outputs": [],
      "source": [
        "show_batch(packed_train_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uG8pIl3P8lnR"
      },
      "outputs": [],
      "source": [
        "example_batch, labels_batch = next(iter(packed_train_data))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BEAHg2wG8lnV"
      },
      "source": [
        "#### データの正規化\n",
        "\n",
        "連続値をとるデータは常に正規化されるべきでしょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XPfSfaDf8lnW"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "desc = pd.read_csv(train_file_path)[NUMERIC_FEATURES].describe()\n",
        "desc"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bULc6B1H8lnZ"
      },
      "outputs": [],
      "source": [
        "MEAN = np.array(desc.T['mean'])\n",
        "STD = np.array(desc.T['std'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d779EGJk8lnc"
      },
      "outputs": [],
      "source": [
        "def normalize_numeric_data(data, mean, std):\n",
        "  # Center the data\n",
        "  return (data-mean)/std\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N_mmQoYR8lnk"
      },
      "source": [
        "数値型の列を作成しましょう。`tf.feature_columns.numeric_column` では `normalizer_fn` 引数で正則化のための関数を受け取ります。また、渡された関数はそれぞれのバッチに対して実行されます。\n",
        "\n",
        "[`functools.partial`](https://docs.python.org/3/library/functools.html#functools.partial) を用いて関数を部分適用し、`normalize_numeric_data` の引数 `MEAN` と `STD` を固定しておきましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v_qY964o8lnk"
      },
      "outputs": [],
      "source": [
        "# See what you just created.\n",
        "normalizer = functools.partial(normalize_numeric_data, mean=MEAN, std=STD)\n",
        "\n",
        "numeric_column = tf.feature_column.numeric_column('numeric', normalizer_fn=normalizer, shape=[len(NUMERIC_FEATURES)])\n",
        "numeric_columns = [numeric_column]\n",
        "numeric_column"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nVZ4L1Xb8lnr"
      },
      "source": [
        "モデルの訓練時には、数値データのブロックを選び出して利用できるように、作成した特徴量の列を入力に含めましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "abISAWXw8lns"
      },
      "outputs": [],
      "source": [
        "example_batch['numeric']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QCLdWKyV8lnw"
      },
      "outputs": [],
      "source": [
        "numeric_layer = tf.keras.layers.DenseFeatures(numeric_columns)\n",
        "numeric_layer(example_batch).numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EHOJtcdx8ln0"
      },
      "source": [
        "ここで用いられた、平均を用いた正規化を利用するためには、それぞれの列の平均値を事前に知っている必要があります。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tSyrkSQwYHKi"
      },
      "source": [
        "### カテゴリデータ\n",
        "\n",
        "この CSV データ中のいくつかの列はカテゴリ列です。つまり、その中身は、限られた選択肢の中のひとつである必要があります。\n",
        "\n",
        "`tf.feature_column` API を用いて、それぞれのカテゴリ列のためのコレクションを `tf.feature_column.indicator_column` で生成しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mWDniduKMw-C"
      },
      "outputs": [],
      "source": [
        "CATEGORIES = {\n",
        "    'sex': ['male', 'female'],\n",
        "    'class' : ['First', 'Second', 'Third'],\n",
        "    'deck' : ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],\n",
        "    'embark_town' : ['Cherbourg', 'Southhampton', 'Queenstown'],\n",
        "    'alone' : ['y', 'n']\n",
        "}\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f1RrrR1N8loB"
      },
      "outputs": [],
      "source": [
        "categorical_columns = []\n",
        "for feature, vocab in CATEGORIES.items():\n",
        "  cat_col = tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "        key=feature, vocabulary_list=vocab)\n",
        "  categorical_columns.append(tf.feature_column.indicator_column(cat_col))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LzveBLUZ8loE"
      },
      "outputs": [],
      "source": [
        "# See what you just created.\n",
        "categorical_columns"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FFn2BNqn8loI"
      },
      "outputs": [],
      "source": [
        "categorical_layer = tf.keras.layers.DenseFeatures(categorical_columns)\n",
        "print(categorical_layer(example_batch).numpy()[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ivic2l6N8loM"
      },
      "source": [
        "モデルを構築する際、これは入力レイヤーで行われるデータ処理の一部になります。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kPWkC4_1l3IG"
      },
      "source": [
        "### 一体化した前処理レイヤー"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jIvyqVAXmsN4"
      },
      "source": [
        "2つの特徴量列のコレクションを結合し、`tf.keras.layers.DenseFeatures` に入力して、両方のデータ型を読み込んで前処理する入力レイヤーを作成しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "__o_rUqX8loO"
      },
      "outputs": [],
      "source": [
        "preprocessing_layer = tf.keras.layers.DenseFeatures(categorical_columns+numeric_columns)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gBQf3r718loR"
      },
      "outputs": [],
      "source": [
        "print(preprocessing_layer(example_batch).numpy()[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DlF_omQqtnOP"
      },
      "source": [
        "## モデルの構築"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lQoFh16LxtT_"
      },
      "source": [
        "`preprocessing_layer` から始まる `tf.keras.Sequential` を構築しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JDM3FIgHNCW3"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.Sequential([\n",
        "  preprocessing_layer,\n",
        "  tf.keras.layers.Dense(128, activation='relu'),\n",
        "  tf.keras.layers.Dense(128, activation='relu'),\n",
        "  tf.keras.layers.Dense(1),\n",
        "])\n",
        "\n",
        "model.compile(\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hPdtI2ie0lEZ"
      },
      "source": [
        "## 訓練、評価、そして予測"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8gvw1RE9zXkD"
      },
      "source": [
        "これでモデルをインスタンス化し、訓練することができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q_nm28IzNDTO"
      },
      "outputs": [],
      "source": [
        "train_data = packed_train_data.shuffle(500)\n",
        "test_data = packed_test_data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ehFi8iUd8loe"
      },
      "outputs": [],
      "source": [
        "model.fit(train_data, epochs=20)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QyDMgBurzqQo"
      },
      "source": [
        "モデルの訓練が終わったら、`test_data` データセットでの正解率をチェックできます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eB3R3ViVONOp"
      },
      "outputs": [],
      "source": [
        "test_loss, test_accuracy = model.evaluate(test_data)\n",
        "\n",
        "print('\\n\\nTest Loss {}, Test Accuracy {}'.format(test_loss, test_accuracy))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sTrn_pD90gdJ"
      },
      "source": [
        "単一のバッチ、または、バッチからなるデータセットのラベルを推論する場合には、`tf.keras.Model.predict` を使います。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qwcx74F3ojqe"
      },
      "outputs": [],
      "source": [
        "predictions = model.predict(test_data)\n",
        "\n",
        "# 結果のいくつかを表示\n",
        "for prediction, survived in zip(predictions[:10], list(test_data)[0][1][:10]):\n",
        "  print(\"Predicted survival: {:.2%}\".format(prediction[0]),\n",
        "        \" | Actual outcome: \",\n",
        "        (\"SURVIVED\" if bool(survived) else \"DIED\"))\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "csv.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
